import cv2
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from nanodet.util import (
    bbox2distance,
    distance2bbox,
    images_to_levels,
    multi_apply,
    overlay_bbox_cv,
)

from ...data.transform.warp import warp_boxes
from ..loss.gfocal_loss import DistributionFocalLoss, QualityFocalLoss
from ..loss.iou_loss import GIoULoss, bbox_overlaps
from ..module.conv import ConvModule
from ..module.init_weights import normal_init
from ..module.nms import multiclass_nms
from ..module.scale import Scale
from .assigner.atss_assigner import ATSSAssigner


def reduce_mean(tensor):
    """Average ``tensor`` across all processes of the default process group.

    If torch.distributed is unavailable or not initialised, the input is
    returned unchanged (single-process run).

    Otherwise each rank divides its value by the world size and the results
    are summed with ``all_reduce``, which yields the mean across ranks. The
    input tensor itself is never modified.

    :param tensor: tensor to average. Integer tensors are promoted to float
        by the division.
    :return: the averaged tensor (a new tensor in the distributed case).
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Distributed training not available / not initialised: nothing to reduce.
        return tensor
    # ``true_divide`` returns a NEW tensor, and ``all_reduce`` operates in
    # place. The reduction must therefore be applied to the very tensor we
    # return. Reducing an unnamed temporary silently discards the result.
    tensor = tensor.true_divide(dist.get_world_size())
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor


class Integral(nn.Module):
    """Fixed (non-learnable) layer that turns a discrete distribution into its expectation.

    Every box side is predicted as ``reg_max + 1`` logits. The logits are
    softmax-normalised into a probability vector ``P(y_i)`` over the support
    ``y_i in {0, 1, ..., reg_max}``. The layer returns the expected value
    :math:`sum{P(y_i) * y_i}`.

    Args:
        reg_max (int): Largest value of the discrete support. Default: 16.
            You may want to change it to suit your dataset or the head
            configuration.
    """

    def __init__(self, reg_max=16):
        super().__init__()
        self.reg_max = reg_max
        # Support points [0, 1, ..., reg_max]. Kept as a buffer so the values
        # travel with the module's device but are not trained.
        support = torch.linspace(0, self.reg_max, self.reg_max + 1)
        self.register_buffer("project", support)

    def forward(self, x):
        """Integrate the regression-branch logits into box-side distances.

        Args:
            x (Tensor): Regression logits whose element count is a multiple of
                ``4 * (reg_max + 1)``, e.g. shape (N, 4 * (reg_max + 1)).

        Returns:
            Tensor: Expected distances from the location to the four box
                sides, shape (N, 4), each value in [0, reg_max].
        """
        num_bins = self.reg_max + 1
        # One row per box side. After softmax each row is a probability vector.
        probs = F.softmax(x.reshape(-1, num_bins), dim=1)
        # The dot product with the support points gives the expectation per row.
        expectation = F.linear(probs, self.project.type_as(probs))
        return expectation.reshape(-1, 4)


class GFLHead(nn.Module):
    """Generalized Focal Loss: Learning Qualified and Distributed Bounding
    Boxes for Dense Object Detection.

    GFL head structure is similar with ATSS, however GFL uses
    1) joint representation for classification and localization quality, and
    2) flexible General distribution for bounding box locations,
    which are supervised by
    Quality Focal Loss (QFL) and Distribution Focal Loss (DFL), respectively

    https://arxiv.org/abs/2006.04388

    :param num_classes: Number of categories excluding the background category.
    :param loss: Config of all loss functions. Reads the ``loss_qfl``
        (``use_sigmoid``, ``beta``, ``loss_weight``), ``loss_dfl``
        (``loss_weight``) and ``loss_bbox`` (``loss_weight``) sub-configs.
    :param input_channel: Number of channels in the input feature map.
    :param feat_channels: Number of hidden channels of the convs in the cls
        and reg towers. Default: 256.
    :param stacked_convs: Number of conv layers in cls and reg tower. Default: 4.
    :param octave_base_scale: Scale factor of grid cells. The square cell used
        for target assignment has side ``stride * octave_base_scale``.
        Default: 4.
    :param strides: Down sample strides of all level feature maps.
        Default: (8, 16, 32).
    :param conv_cfg: Dictionary to construct and config conv layer. Default: None.
    :param norm_cfg: Dictionary to construct and config norm layer.
    :param reg_max: Max value of integral set :math: `{0, ..., reg_max}`
        used by the distribution (DFL) regression branch. Default: 16.
    :param kwargs: Accepted for config compatibility; not used.
    """

    def __init__(
        self,
        num_classes,
        loss,
        input_channel,
        feat_channels=256,
        stacked_convs=4,
        octave_base_scale=4,
        strides=(8, 16, 32),
        conv_cfg=None,
        norm_cfg=dict(type="GN", num_groups=32, requires_grad=True),
        reg_max=16,
        **kwargs
    ):
        super(GFLHead, self).__init__()
        self.num_classes = num_classes
        self.in_channels = input_channel
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.grid_cell_scale = octave_base_scale
        # Keep a private copy so the instance never aliases a caller's list.
        # The default is an immutable tuple for the same reason.
        self.strides = list(strides)
        self.reg_max = reg_max

        # Loss config: QualityFocalLoss + DistributionFocalLoss + GIoULoss.
        self.loss_cfg = loss
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.use_sigmoid = self.loss_cfg.loss_qfl.use_sigmoid
        if self.use_sigmoid:
            # Sigmoid (per-class binary) scores need no background channel.
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        self.assigner = ATSSAssigner(topk=9)
        self.distribution_project = Integral(self.reg_max)

        self.loss_qfl = QualityFocalLoss(
            use_sigmoid=self.use_sigmoid,
            beta=self.loss_cfg.loss_qfl.beta,
            loss_weight=self.loss_cfg.loss_qfl.loss_weight,
        )
        self.loss_dfl = DistributionFocalLoss(
            loss_weight=self.loss_cfg.loss_dfl.loss_weight
        )
        self.loss_bbox = GIoULoss(loss_weight=self.loss_cfg.loss_bbox.loss_weight)
        self._init_layers()
        self.init_weights()

    def _init_layers(self):
        """Build the cls/reg conv towers, the two prediction convs and per-level scales."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                )
            )
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                )
            )
        self.gfl_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1
        )
        # 4 box sides, each predicted as a distribution over reg_max + 1 bins.
        self.gfl_reg = nn.Conv2d(
            self.feat_channels, 4 * (self.reg_max + 1), 3, padding=1
        )
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def init_weights(self):
        """Initialise all convs with N(0, 0.01). Classification bias gets a low prior."""
        for m in self.cls_convs:
            normal_init(m.conv, std=0.01)
        for m in self.reg_convs:
            normal_init(m.conv, std=0.01)
        # -4.595 ~= -log((1 - 0.01) / 0.01): initial sigmoid score ~0.01.
        bias_cls = -4.595
        normal_init(self.gfl_cls, std=0.01, bias=bias_cls)
        normal_init(self.gfl_reg, std=0.01)

    def forward(self, feats):
        """Run the head on every feature level.

        :param feats: sequence of feature maps, one per entry of ``self.strides``.
        :return: ``(cls_scores, bbox_preds)`` as produced by ``multi_apply``,
            i.e. one output per level for each branch.
        """
        return multi_apply(self.forward_single, feats, self.scales)

    def forward_single(self, x, scale):
        """Run the cls and reg towers on a single feature level.

        :param x: feature map of one level.
        :param scale: learnable ``Scale`` module applied to the regression output.
        :return: ``(cls_score, bbox_pred)``. ``bbox_pred`` is cast to float32.
        """
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.gfl_cls(cls_feat)
        bbox_pred = scale(self.gfl_reg(reg_feat)).float()
        return cls_score, bbox_pred

    def loss(self, preds, gt_meta):
        """Compute the QFL, GIoU and DFL losses of a batch.

        :param preds: ``(cls_scores, bbox_preds)``. Each is a per-level list of
            tensors shaped (batch, cls_out_channels, h, w) and
            (batch, 4 * (reg_max + 1), h, w) respectively.
        :param gt_meta: dict with ``gt_bboxes`` (per-image (N, 4) boxes in
            input-image pixels) and ``gt_labels`` (per-image (N,) class ids).
        :return: ``(loss, loss_states)``, or ``None`` when target assignment
            produced no valid labels.
        """
        cls_scores, bbox_preds = preds

        batch_size = cls_scores[0].shape[0]
        device = cls_scores[0].device
        gt_bboxes = gt_meta["gt_bboxes"]
        gt_labels = gt_meta["gt_labels"]
        gt_bboxes_ignore = None
        # (h, w) of every level's feature map.
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        # Assign positive / negative samples.
        cls_reg_targets = self.target_assign(
            batch_size,
            featmap_sizes,
            gt_bboxes,
            gt_bboxes_ignore,
            gt_labels,
            device=device,
        )
        if cls_reg_targets is None:
            return None

        (
            grid_cells_list,  # per level: (batch, num_cells_level, 4)
            labels_list,  # per level: (batch, num_cells_level)
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_pos,  # positive cells in the whole batch
            num_total_neg,  # negative cells in the whole batch
        ) = cls_reg_targets
        # Float tensor so the cross-process average is not truncated.
        num_total_samples = reduce_mean(
            torch.tensor(num_total_pos, dtype=torch.float32, device=device)
        ).item()
        num_total_samples = max(num_total_samples, 1.0)

        # One loss_single call per output level.
        losses_qfl, losses_bbox, losses_dfl, avg_factor = multi_apply(
            self.loss_single,
            grid_cells_list,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            self.strides,
            num_total_samples=num_total_samples,
        )

        avg_factor = sum(avg_factor)
        avg_factor = reduce_mean(avg_factor).item()
        if avg_factor <= 0:
            loss_qfl = torch.tensor(0, dtype=torch.float32, requires_grad=True).to(
                device
            )
            loss_bbox = torch.tensor(0, dtype=torch.float32, requires_grad=True).to(
                device
            )
            loss_dfl = torch.tensor(0, dtype=torch.float32, requires_grad=True).to(
                device
            )
        else:
            # Regression losses are normalised by the summed quality weights.
            losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox))
            losses_dfl = list(map(lambda x: x / avg_factor, losses_dfl))

            loss_qfl = sum(losses_qfl)
            loss_bbox = sum(losses_bbox)
            loss_dfl = sum(losses_dfl)

        loss = loss_qfl + loss_bbox + loss_dfl
        loss_states = dict(loss_qfl=loss_qfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)

        return loss, loss_states

    def loss_single(
        self,
        grid_cells,
        cls_score,
        bbox_pred,
        labels,
        label_weights,
        bbox_targets,
        stride,
        num_total_samples,
    ):
        """Compute the losses of ONE output level (called once per level).

        :param grid_cells: (batch, num_cells, 4) cell boxes, xyxy in image pixels.
        :param cls_score: (batch, cls_out_channels, h, w) raw class logits.
        :param bbox_pred: (batch, 4 * (reg_max + 1), h, w) distribution logits.
        :param labels: (batch, num_cells) class ids. Background is ``num_classes``.
        :param label_weights: (batch, num_cells) classification weights.
        :param bbox_targets: (batch, num_cells, 4) assigned gt boxes, xyxy pixels.
        :param stride: down-sample stride of this level.
        :param num_total_samples: number of positives in the batch, used as
            the QFL normaliser.
        :return: ``(loss_qfl, loss_bbox, loss_dfl, weight_sum)``.
        """
        grid_cells = grid_cells.reshape(-1, 4)
        # (batch, C, h, w) -> (batch, h, w, C) -> (batch * h * w, C)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4 * (self.reg_max + 1))
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = torch.nonzero(
            (labels >= 0) & (labels < bg_class_ind), as_tuple=False
        ).squeeze(1)

        # IoU quality target for QFL. It stays 0 for negatives.
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]  # (pos_num, 4 * (reg_max + 1))
            pos_grid_cells = grid_cells[pos_inds]
            # Cell centres converted from image pixels to feature-map units.
            pos_grid_cell_centers = self.grid_cells_to_center(pos_grid_cells) / stride

            # Per-positive regression weight: the highest predicted class
            # probability at that location, not necessarily the probability
            # of the assigned class. Detached, so no gradient flows into it.
            weight_targets = cls_score.detach().sigmoid()
            weight_targets = weight_targets.max(dim=1)[0][pos_inds]
            # (pos_num, 4) expected distances from the centre to each box side.
            pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred)

            # Decoded predicted boxes, xyxy in feature-map units.
            pos_decode_bbox_pred = distance2bbox(
                pos_grid_cell_centers, pos_bbox_pred_corners
            )
            # Ground-truth boxes, xyxy in feature-map units (absolute coordinates).
            pos_decode_bbox_targets = pos_bbox_targets / stride
            # (pos_num,) IoU between each prediction and its matched gt.
            score[pos_inds] = bbox_overlaps(
                pos_decode_bbox_pred.detach(), pos_decode_bbox_targets, is_aligned=True
            )

            pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)  # (pos_num * 4, reg_max + 1)
            # (pos_num * 4,) target distances from the centre to the gt sides.
            target_corners = bbox2distance(
                pos_grid_cell_centers, pos_decode_bbox_targets, self.reg_max
            ).reshape(-1)

            # Regression loss (GIoU), weighted per positive by weight_targets.
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=weight_targets,
                avg_factor=1.0,
            )

            # Distribution focal loss. Each positive's weight is repeated for its 4 sides.
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
                avg_factor=4.0,
            )
        else:
            # No positives on this level: zero losses that stay in the graph.
            loss_bbox = bbox_pred.sum() * 0
            loss_dfl = bbox_pred.sum() * 0
            weight_targets = torch.tensor(0).to(cls_score.device)

        # Classification/quality loss is computed whether or not positives exist.
        loss_qfl = self.loss_qfl(
            cls_score,  # (batch * h * w, cls_out_channels) raw logits
            (labels, score),  # class ids and IoU quality targets, both (batch * h * w,)
            weight=label_weights,
            avg_factor=num_total_samples,
        )

        return loss_qfl, loss_bbox, loss_dfl, weight_targets.sum()

    def target_assign(
        self,
        batch_size,
        featmap_sizes,
        gt_bboxes_list,
        gt_bboxes_ignore_list,
        gt_labels_list,
        device,
    ):
        """
        Assign target for a batch of images.
        :param batch_size: num of images in one batch
        :param featmap_sizes: (h, w) of every level's feature map
        :param gt_bboxes_list: A list of ground truth boxes in all image
        :param gt_bboxes_ignore_list: A list of all ignored boxes in all image
        :param gt_labels_list: A list of all ground truth label in all image
        :param device: pytorch device
        :return: Assign results of all images, split per level, or ``None``
            if any image produced no labels.
        """
        # Grid cells of one image: one (h * w, 4) tensor per level.
        multi_level_grid_cells = [
            self.get_grid_cells(
                featmap_sizes[i],
                self.grid_cell_scale,
                stride,
                dtype=torch.float32,
                device=device,
            )
            for i, stride in enumerate(self.strides)
        ]
        # The cells are identical for every image in the batch.
        mlvl_grid_cells_list = [multi_level_grid_cells for i in range(batch_size)]

        # Number of cells (h * w) per level.
        num_level_cells = [grid_cells.size(0) for grid_cells in mlvl_grid_cells_list[0]]
        num_level_cells_list = [num_level_cells] * batch_size
        # concat all level cells and to a single tensor
        for i in range(batch_size):
            mlvl_grid_cells_list[i] = torch.cat(mlvl_grid_cells_list[i])

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(batch_size)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(batch_size)]
        # Target assignment on all images. Each output is a list of length
        # batch_size whose tensors have the total cell count as first dim.
        (
            all_grid_cells,
            all_labels,  # per image: (num_cells,) class ids, background = num_classes
            all_label_weights,
            all_bbox_targets,  # per image: (num_cells, 4) assigned gt boxes
            all_bbox_weights,
            pos_inds_list,  # per image: indices of positive cells
            neg_inds_list,  # per image: indices of negative cells
        ) = multi_apply(
            self.target_assign_single_img,
            mlvl_grid_cells_list,
            num_level_cells_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list,
        )
        # no valid cells
        if any(labels is None for labels in all_labels):
            return None
        # sampled cells of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # Merge the per-image targets, then split them into per-level
        # tensors of shape (batch, num_cells_level, ...).
        mlvl_grid_cells = images_to_levels(all_grid_cells, num_level_cells)
        mlvl_labels = images_to_levels(all_labels, num_level_cells)
        mlvl_label_weights = images_to_levels(all_label_weights, num_level_cells)
        mlvl_bbox_targets = images_to_levels(all_bbox_targets, num_level_cells)
        mlvl_bbox_weights = images_to_levels(all_bbox_weights, num_level_cells)
        return (
            mlvl_grid_cells,
            mlvl_labels,
            mlvl_label_weights,
            mlvl_bbox_targets,
            mlvl_bbox_weights,
            num_total_pos,
            num_total_neg,
        )

    def target_assign_single_img(
        self, grid_cells, num_level_cells, gt_bboxes, gt_bboxes_ignore, gt_labels
    ):
        """
        Using ATSS Assigner to assign target on one image.
        :param grid_cells: (num_cells, 4) grid cell boxes of all levels, concatenated
        :param num_level_cells: number of grid cells on each level's feature map
        :param gt_bboxes: Ground truth boxes, numpy array of shape (gt_num, 4)
        :param gt_bboxes_ignore: Ground truths which are ignored
        :param gt_labels: Ground truth class indices, numpy array of shape (gt_num,)
        :return: Assign results of a single image
        """
        device = grid_cells.device
        gt_bboxes = torch.from_numpy(gt_bboxes).to(device)
        gt_labels = torch.from_numpy(gt_labels).to(device)
        # Every cell is classified as positive / negative / ignored here.
        assign_result = self.assigner.assign(
            grid_cells, num_level_cells, gt_bboxes, gt_bboxes_ignore, gt_labels
        )
        # Positive and negative cell indices, the gt box of each positive,
        # and each positive's gt index (range [0, gt_num - 1]).
        pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.sample(
            assign_result, gt_bboxes
        )
        num_cells = grid_cells.shape[0]
        bbox_targets = torch.zeros_like(grid_cells)
        bbox_weights = torch.zeros_like(grid_cells)
        # Every cell starts as background (label == num_classes).
        labels = grid_cells.new_full((num_cells, ), self.num_classes, dtype=torch.long)
        label_weights = grid_cells.new_zeros(num_cells, dtype=torch.float)
        if len(pos_inds) > 0:
            pos_bbox_targets = pos_gt_bboxes
            # Positives receive their gt box. Other cells keep zeros.
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[pos_assigned_gt_inds]

            label_weights[pos_inds] = 1.0
        if len(neg_inds) > 0:
            # Negatives also contribute to the classification (QFL) loss.
            # Ignored cells keep weight 0.
            label_weights[neg_inds] = 1.0

        return (
            grid_cells,
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            pos_inds,
            neg_inds,
        )

    def sample(self, assign_result, gt_bboxes):
        """Split assignment results into positive / negative cells.

        ``assign_result.gt_inds`` holds 0 for negatives and ``k + 1`` for a
        cell assigned to gt ``k``. Other values (e.g. negative ones) are
        neither positive nor negative.

        :return: ``(pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds)``.
        """
        pos_inds = (
            torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
            .squeeze(-1)
            .unique()
        )
        neg_inds = (
            torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
            .squeeze(-1)
            .unique()
        )
        # 0-based index of the gt each positive cell was assigned to.
        pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
        if gt_bboxes.numel() == 0:
            # hack for index error case
            assert pos_assigned_gt_inds.numel() == 0
            pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
        else:
            if len(gt_bboxes.shape) < 2:
                gt_bboxes = gt_bboxes.view(-1, 4)
            # Gather the matched gt box for every positive cell.
            pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
        return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds

    def post_process(self, preds, meta):
        """Decode predictions and map boxes back to original image coordinates.

        :param preds: ``(cls_scores, bbox_preds)`` head outputs.
        :param meta: batch meta with ``warp_matrix``, ``img`` and ``img_info``
            (``height``, ``width``, ``id``).
        :return: dict ``img_id -> {class_idx: [[x1, y1, x2, y2, score], ...]}``.
        """
        cls_scores, bbox_preds = preds
        result_list = self.get_bboxes(cls_scores, bbox_preds, meta)
        det_results = {}
        warp_matrixes = meta["warp_matrix"]
        img_heights = (
            meta["img_info"]["height"].cpu().numpy()
            if isinstance(meta["img_info"]["height"], torch.Tensor)
            else meta["img_info"]["height"]
        )
        img_widths = (
            meta["img_info"]["width"].cpu().numpy()
            if isinstance(meta["img_info"]["width"], torch.Tensor)
            else meta["img_info"]["width"]
        )
        img_ids = (
            meta["img_info"]["id"].cpu().numpy()
            if isinstance(meta["img_info"]["id"], torch.Tensor)
            else meta["img_info"]["id"]
        )

        for result, img_width, img_height, img_id, warp_matrix in zip(
            result_list, img_widths, img_heights, img_ids, warp_matrixes
        ):
            det_result = {}
            det_bboxes, det_labels = result
            det_bboxes = det_bboxes.cpu().numpy()
            # Undo the preprocessing warp to get original-image coordinates.
            det_bboxes[:, :4] = warp_boxes(
                det_bboxes[:, :4], np.linalg.inv(warp_matrix), img_width, img_height
            )
            classes = det_labels.cpu().numpy()
            for i in range(self.num_classes):
                inds = classes == i
                det_result[i] = np.concatenate(
                    [
                        det_bboxes[inds, :4].astype(np.float32),
                        det_bboxes[inds, 4:5].astype(np.float32),
                    ],
                    axis=1,
                ).tolist()
            det_results[img_id] = det_result
        return det_results

    def show_result(
        self, img, dets, class_names, score_thres=0.3, show=True, save_path=None
    ):
        """Draw detections on ``img`` and optionally display them.

        ``save_path`` is accepted for API compatibility but is currently unused.
        """
        result = overlay_bbox_cv(img, dets, class_names, score_thresh=score_thres)
        if show:
            cv2.imshow("det", result)
        return result

    def get_bboxes(self, cls_scores, bbox_preds, img_metas, rescale=False):
        """Decode per-level head outputs into detections for every image in the batch."""
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)
        device = cls_scores[0].device

        input_height, input_width = img_metas["img"].shape[2:]
        input_shape = [input_height, input_width]

        result_list = []
        for img_id in range(cls_scores[0].shape[0]):
            cls_score_list = [cls_scores[i][img_id].detach() for i in range(num_levels)]
            bbox_pred_list = [bbox_preds[i][img_id].detach() for i in range(num_levels)]
            scale_factor = 1
            dets = self.get_bboxes_single(
                cls_score_list,
                bbox_pred_list,
                input_shape,
                scale_factor,
                device,
                rescale,
            )

            result_list.append(dets)
        return result_list

    def get_bboxes_single(
        self, cls_scores, bbox_preds, img_shape, scale_factor, device, rescale=False
    ):
        """
        Decode output tensors to bboxes on one image.
        Keeps the top 1000 locations per level, then runs class-wise NMS
        (score > 0.05, IoU 0.6, at most 100 detections).
        :param cls_scores: classification prediction tensors of all stages
        :param bbox_preds: regression prediction tensors of all stages
        :param img_shape: shape of input image
        :param scale_factor: scale factor of boxes
        :param device: device of the tensor
        :return: predict boxes and labels
        """
        assert len(cls_scores) == len(bbox_preds)
        mlvl_bboxes = []
        mlvl_scores = []
        for stride, cls_score, bbox_pred in zip(self.strides, cls_scores, bbox_preds):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            featmap_size = cls_score.size()[-2:]
            y, x = self.get_single_level_center_point(
                featmap_size, stride, cls_score.dtype, device, flatten=True
            )
            center_points = torch.stack([x, y], dim=-1)
            scores = (
                cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels).sigmoid()
            )
            bbox_pred = bbox_pred.permute(1, 2, 0)
            # Expected side distances, converted from feature-map units to pixels.
            bbox_pred = self.distribution_project(bbox_pred) * stride

            nms_pre = 1000
            if scores.shape[0] > nms_pre:
                max_scores, _ = scores.max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                center_points = center_points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]

            bboxes = distance2bbox(center_points, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)

        mlvl_scores = torch.cat(mlvl_scores)
        # add a dummy background class at the end of all labels
        # same with mmdetection2.0
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)

        det_bboxes, det_labels = multiclass_nms(
            mlvl_bboxes,
            mlvl_scores,
            score_thr=0.05,
            nms_cfg=dict(type="nms", iou_threshold=0.6),
            max_num=100,
        )
        return det_bboxes, det_labels

    def get_single_level_center_point(
        self, featmap_size, stride, dtype, device, flatten=True
    ):
        """
        Generate pixel centers of a single stage feature map.
        :param featmap_size: height and width of the feature map
        :param stride: down sample stride of the feature map
        :param dtype: data type of the tensors
        :param device: device of the tensors
        :param flatten: flatten the x and y tensors
        :return: y and x of the center points
        """
        h, w = featmap_size
        # Cell centres in feature-map units (i + 0.5), scaled by the stride
        # into input-image pixel coordinates.
        x_range = (torch.arange(w, dtype=dtype, device=device) + 0.5) * stride
        y_range = (torch.arange(h, dtype=dtype, device=device) + 0.5) * stride
        y, x = torch.meshgrid(y_range, x_range)
        if flatten:
            y = y.flatten()
            x = x.flatten()
        return y, x

    def get_grid_cells(self, featmap_size, scale, stride, dtype, device):
        """
        Generate grid cells of a feature map for target assignment.
        :param featmap_size: Size of a single level feature map.
        :param scale: Grid cell scale.
        :param stride: Down sample stride of the feature map.
        :param dtype: Data type of the tensors.
        :param device: Device of the tensors.
        :return: Grid_cells xyxy position. Size should be [feat_w * feat_h, 4]
        """
        # Side length of the square cell placed at every location.
        cell_size = stride * scale
        y, x = self.get_single_level_center_point(
            featmap_size, stride, dtype, device, flatten=True
        )
        grid_cells = torch.stack(
            [
                x - 0.5 * cell_size,
                y - 0.5 * cell_size,
                x + 0.5 * cell_size,
                y + 0.5 * cell_size,
            ],
            dim=-1,
        )
        return grid_cells

    def grid_cells_to_center(self, grid_cells):
        """
        Get center location of each gird cell
        :param grid_cells: grid cells of a feature map
        :return: center points
        """
        cells_cx = (grid_cells[:, 2] + grid_cells[:, 0]) / 2
        cells_cy = (grid_cells[:, 3] + grid_cells[:, 1]) / 2
        return torch.stack([cells_cx, cells_cy], dim=-1)
